fix masking error when using mlp_bias=True causing NaN during gradien… #3699
Open
snehalv2002 wants to merge 1 commit into main from
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.
RissyRan (Collaborator) reviewed Apr 20, 2026 and left a comment:
Thanks! Could you have a logits correctness check with EP>1 using GPT-OSS?
RissyRan reviewed Apr 20, 2026
```python
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
    layer_w0 = layer_w0 + w0_bias
    layer_w0 = jnp.where(mask[:, None], layer_w0, 0)
```
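For context, a minimal self-contained sketch of why this masking matters. The toy buffer, shapes, and NaN padding values below are illustrative assumptions, not MaxText's real configuration; only the `w0_bias`/`mask`/`jnp.where` pattern mirrors the diff above.

```python
import jax
import jax.numpy as jnp


def loss(w0_bias, buf, mask):
    # Add the MLP bias to every row of the buffer, then zero the padded
    # rows so they are disconnected from the backward graph.
    layer_w0 = buf + w0_bias
    layer_w0 = jnp.where(mask[:, None], layer_w0, 0)
    return jnp.sum(layer_w0)


# A ragged_all_to_all-style buffer: 2 real token rows plus 2 padding rows
# holding garbage values (NaN here, to make the failure mode visible).
buf = jnp.array([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [jnp.nan, jnp.nan, jnp.nan],
                 [jnp.nan, jnp.nan, jnp.nan]])
mask = jnp.array([True, True, False, False])
w0_bias = jnp.ones(3)

value, grad = jax.value_and_grad(loss)(w0_bias, buf, mask)
# With the mask, the loss is finite (6.0) and each bias element receives
# gradient 2.0 — one contribution per real token row.
```

Without the `jnp.where` line, the NaN padding rows flow straight into the loss, and even with zero-filled padding the bias gradient would count all four rows instead of the two real tokens.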
Collaborator
Is default mask defined without EP sharding?
Description
Bug fix for b/497864549: [XL ML] NaN training loss for GPT-OSS SFT.
GPT-OSS is currently experiencing NaN training loss after one step when using expert_parallelism > 1. The bug was caused by a lack of masking of the MLP bias inside the expert computation. Since no other models use the MLP bias, they did not experience this issue.

Issue
Currently, when expert_parallelism > 1 we introduce a buffer to store the output of ragged_all_to_all, which may contain padding along the token axis. If padding values are not masked after adding the MLP bias, JAX will include them in the gradient computation. jnp.where allows us to disconnect the padding values from the backward graph during bias gradient calculation.

Tests
Ran pre_train.train on GPT_OSS with expert_parallelism > 1: https://paste.googleplex.com/6279941382602752

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.